{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Visualization",
      "provenance": [],
      "collapsed_sections": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/sayakpaul/Training-BatchNorm-and-Only-BatchNorm/blob/master/Visualization_II.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mD10sl5PlAHc",
        "colab_type": "text"
      },
      "source": [
        "References:\n",
        "- https://machinelearningmastery.com/how-to-visualize-filters-and-feature-maps-in-convolutional-neural-networks/"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "lIYdn1woOS1n",
        "outputId": "282562d6-200e-49d5-9d0c-baeef8be646a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "import tensorflow as tf\n",
        "print(tf.__version__)"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "2.2.0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3mIUGzSI1Zqk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install the Weights & Biases client quietly, then import it and log this session in.\n",
        "!pip install wandb -q\n",
        "import wandb\n",
        "wandb.login()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ElQoPqrB33BY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Fetch the ResNet helper script; it is imported below as the `resnet_cifar10` module.\n",
        "!wget https://raw.githubusercontent.com/GoogleCloudPlatform/keras-idiomatic-programmer/master/zoo/resnet/resnet_cifar10.py"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gu3QWkkK37QC",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensorflow.keras.layers import *\n",
        "from tensorflow.keras.models import *\n",
        "from wandb.keras import WandbCallback\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "import resnet_cifar10\n",
        "import numpy as np\n",
        "import cv2\n",
        "import os\n",
        "\n",
        "# Random seed fixation\n",
        "tf.random.set_seed(666)\n",
        "np.random.seed(666)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zAiLfgr93ioB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "LABELS = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\",\n",
        "\t\"frog\", \"horse\", \"ship\", \"truck\"]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LwDvrzAl3Mab",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        },
        "outputId": "ad6f8611-370d-4364-b7e6-756856392c78"
      },
      "source": [
        "# Only the test split is used in this notebook; the training split is discarded.\n",
        "(_, _), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
        "# Rescale by 1/255 (presumably mapping 8-bit pixel values into [0, 1] -- confirm against the training setup).\n",
        "x_test = x_test/255."
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\n",
            "170500096/170498071 [==============================] - 2s 0us/step\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "h_DMb7o03G4j",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_training_model(optimizer=\"sgd\"):\n",
        "    \"\"\"Build and compile the ResNet used throughout this notebook.\n",
        "\n",
        "    The network is assembled from the stem, learner and classifier helpers\n",
        "    of the downloaded `resnet_cifar10` module, then compiled with sparse\n",
        "    categorical cross-entropy loss and an accuracy metric.\n",
        "\n",
        "    Args:\n",
        "        optimizer: Optimizer forwarded unchanged to `model.compile`.\n",
        "\n",
        "    Returns:\n",
        "        The compiled Keras `Model` taking (32, 32, 3) inputs and ending in\n",
        "        the 10-class classifier head.\n",
        "    \"\"\"\n",
        "    # Depth bookkeeping (labelled ResNet20 in the original): n = 2 gives depth 20\n",
        "    n = 2\n",
        "    depth = 9 * n + 2\n",
        "    # Evaluates to 1 for depth 20\n",
        "    n_blocks = (depth - 2) // 9 - 1\n",
        "\n",
        "    image_input = Input(shape=(32, 32, 3))\n",
        "\n",
        "    # stem -> learner -> 10-way classifier, chained functionally\n",
        "    features = resnet_cifar10.stem(image_input)\n",
        "    features = resnet_cifar10.learner(features, n_blocks)\n",
        "    predictions = resnet_cifar10.classifier(features, 10)\n",
        "\n",
        "    network = Model(image_input, predictions)\n",
        "    network.compile(\n",
        "        optimizer=optimizer,\n",
        "        loss=\"sparse_categorical_crossentropy\",\n",
        "        metrics=[\"accuracy\"],\n",
        "    )\n",
        "    return network"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kHLLJSeA5lFv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_sample_test_image():\n",
        "    \"\"\"Pick one random test image, show it titled with its class name, and return its index.\n",
        "\n",
        "    Returns:\n",
        "        The length-1 index array produced by `np.random.choice`, so that\n",
        "        `x_test[idx]` keeps a leading axis of size 1.\n",
        "    \"\"\"\n",
        "    idx = np.random.choice(x_test.shape[0], 1)\n",
        "    # `idx` is a 1-element array, so the lookup yields a batch of one; drop that axis for imshow\n",
        "    plt.imshow(x_test[idx].squeeze(0))\n",
        "    # `y_test[idx]` is a single-element array with ndim > 0. Extract the scalar explicitly with\n",
        "    # `.item()`: calling int() directly on such an array is deprecated since NumPy 1.25.\n",
        "    plt.title(LABELS[int(y_test[idx].item())])\n",
        "    plt.show()\n",
        "    return idx"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pVZ_1X7458wT",
        "colab_type": "code",
        "outputId": "82db3f60-a711-4471-f738-4945fbee92d7",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 281
        }
      },
      "source": [
        "# Test: show one random test image and keep both its index and its pixels.\n",
        "# The helper returns a length-1 index array, so `sample_image` keeps a leading axis of size 1.\n",
        "sample_image_id = get_sample_test_image()\n",
        "sample_image = x_test[sample_image_id]"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAfkklEQVR4nO2de4xd13Xev3XuY97k8M3R8E3RUuTKkmlKsSI1iZ3YkI0GsositYMa/sMtgyIGaiBJIbhB7RRF6xS1DQMtHNCxECVQ7LiRXNux4FpRUjNOHFkUTVGUKJsUSYmkZjh8zvO+7+of99IYqfvbHHJm7qWzvx8wmDtnzT5n3X3Puufe/Z21lrk7hBD/+Mm67YAQojMo2IVIBAW7EImgYBciERTsQiSCgl2IRFCwC5EICnZx3ZjZKTP71W77Ia4PBbsQiaBgTxwz22xmT5jZeTO7aGb/w8x2mtlft/++YGaPmdlw+///FMAWAN8ysxkz+/fdfQZioZhul00XM8sBOAjgrwH8HoAGgD0AxgFsB7AfwAoAjwM46O6faI87BeBfu/tfdcFtcYPku+2A6Cr3ArgFwO+6e7297fvt38fbv8+b2ecAfKrTzomlRcGeNpsBvDov0AEAZrYBwBcA/FMAQ2h93bvceffEUqLv7GlzGsAWM3vzm/5/AeAA7nT3FQD+FQCbZ9d3v59BFOxp80MAYwA+Y2YDZtZrZvejdTWfATBpZqMAfvdN484B2NFZV8ViUbAnjLs3APwagFsBvAbgDIB/CeD3AewGMAng2wCeeNPQ/wrg98zsipn9Tuc8FotBq/FCJIKu7EIkgoJdiERQsAuRCAp2IRKhozfV5PM5LxQKQVvOLLgdAFV1DXyMR6Tg2JJkdMGSHi7ie8zH2KFiu4w9gdi4Tu3vGru8oVE3Olc3+gTosWKv5w0+azIsujfiR61WQaNeDxoXFexm9iBad1rlAPyRu38m9v+FQgG3btsatA325ui4LOw7ck0+pooGtTWpBag06tTm7HPQ/3dPyvyDhd/cAMAb3JN8LvKhy/k4di565KRvRk5Sj0YS99HJsEbGffcmf80Qmatixv3I5cKvTWR3UfJ5/lrX6/zcidEgzkT3ViwGN586fpQOueGP8e0kiv8J4H0A7gDwYTO740b3J4RYXhbznf1eAMfd/YS7VwF8FcBDS+OWEGKpWUywj6J1b/VVzrS3vQEz22tmB8zsQKMe+ZgmhFhWln013t33ufsed9+Ty/Pv2EKI5WUxwX4WrRTJq2xqbxNC3IQsZjX+WQC7zGw7WkH+IQC/ERvQ19OD23fdGrT93NaNdFyuUgluz5r8verS3DS1zZbC+2vZatRWqYVXTWt0mR4o1fhXF7YKCwC5yGp8Lsc/IVEZJ3KsZmR1vxkZ12jw51bzsK1U4Cv/+cjCfzGiCuRi6oSH56pKFB4gLq8NDg5SW0x6q1X52nqdnCM14+dAlagrWcbPjRsOdnevm9nHAfwftKS3R9z9xRvdnxBieVmUzu7uTwJ4col8EUIsI7pdVohEULALkQgKdiESQcEuRCJ0NOvNLENfoSdoG1k5RMdlpbCkYZFkhmI+fBwAmClweWIq49LbXCkskTQjiTCThSq11SJ3FDZiWV4R6a3RDEsyjYgkgyaXhcwiNvC5YrJoTF6ziHSVRZJdeiLJKTkiXzXBX5e4tMl9bDYjST4ReTDLwv4X8+FkFwAwhH3MInKdruxCJIKCXYhEULALkQgKdiESQcEuRCJ0djUeIGuIQFblK7tWDSeu1Go8oSUfeR8rRKp7xWxssuqRckqFiGRQbfAV4VipKyfJHa19ho/nhX46ph5xo1jkfmSxilWVsB899dj1JZKc0oiUwIooL1kvqXmY469ZrPRUtRpRV2r8HI4VlCsUws/NjQ9q1MnrHDmQruxCJIKCXYhEULALkQgKdiESQcEuRCIo2IVIhI
5Kb+6OOqnFVStH9J9KWNIo13iSRqPQS22xZiCWiyQzkKSQWNefzLmPPTkuNVUi+4z5nycVfOvNMt9fbYbaPJLcsWqQz7H1hk+tmUtTdEwt0uEnK/JEKY/UACyXw/NvkbmP1daL1afLRxJoIvkpKBZJK6dYpx4ivcXQlV2IRFCwC5EICnYhEkHBLkQiKNiFSAQFuxCJ0FnpremoMRmtxKW3jGQ8lZ27X29EbJFsrXKT+1FBWJKpNiLZTpGMOItkeWWR9+FmbByR+laCS15Da8KZYQCwc9MGartl3SpqyxOBcGJ8go45fvoCtZ2+dIXamoVhaitXwvNfiEhvPT1cUoylr83MzFIbSWwDAOTyYR8rVX7u8MzH68/aXBBmdgrANIAGgLq771nM/oQQy8dSXNnf5e78LVkIcVOg7+xCJMJig90BfNfMnjOzvaF/MLO9ZnbAzA5UapFbYoUQy8piP8Y/4O5nzWw9gKfM7GV33z//H9x9H4B9ALB6xcrIHd9CiOVkUVd2dz/b/j0B4OsA7l0Kp4QQS88NX9nNbABA5u7T7cfvBfCfYmPcHTXyUb4Ukd4KpGdQI99Hx1TqPANprlyitlLkq0alES5wOTDEizluXs+lq3MXuZx0ZuIStUW6RmHVirBsdM/W9XTMxjWD1NabRQ5WGaMm1mFrzRYu842ObqW2I69OUtuhV85zP4iM5pEClvfcs5vaYlmA3/vefmorRIpYvuMddwW3/8NzL9AxWf36PyQv5mP8BgBfb6f85QH8mbt/ZxH7E0IsIzcc7O5+AkD4LUkIcdMh6U2IRFCwC5EICnYhEkHBLkQidDbrDU1UmmH56kqFFz3s97BcUyzy7KRqRCKxaiQDrHaR2v7J9luC22/ZtIWOyXmR2jat4D3K7tyyjtqOnTpJbds3hzPRhsGlvNokT21oZPx60BPJ6Osthl+z2Tn+OjcwTW33bwvPPQCszXE578CpsCw3fpH7saGXy6+bN41Q28G/5fvcc9d91Lb7rrcFtz/z7It0jGVE24wUxNSVXYhEULALkQgKdiESQcEuRCIo2IVIhA63fwJqpLVOuc7bJGWk5loTPKGlUuLtjnIVXivsF3bzO4BrtfA+py/yle7KHK9PV6mGlQkAWL2a11XbtZHbehFeSY50NEI146vZjVqkHVaTXysy0ubLCvyUy4HPR3OKJw3dNsqTfHrXjga3n5/i59uqSAm61ZGV+vfcdye17XjLdmqbuxRWQ7KI2tFkhRQjLaN0ZRciERTsQiSCgl2IRFCwC5EICnYhEkHBLkQidFR6AwwgLZvqNX4Df7URln9qlTk6pi/yzG5/K5dIZiMtd06NhRM1pktcMiqQdkxAvCXQ9ByXDgd6uFRmA+G6fDvuCCdbAMCK9Vy6ev7QIWqbmnid+0FaZfXnuK7VV+CJQfVcpB1WpE7eL9z/QHB7YXAjHXPu5EFqu3CWJyHtvGUltWWVy9Q2dSksBWf1iHzs4dc50tlMV3YhUkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkQkelt6Y7qpWwTMK2A0CB9BJqRFo1bdvO2y4Nr1xNbX/37HPUdqEUzmCbqUSktzqvq7ZxNZdqhoo8e6knUheuVA77MjzKs67uuJfXR1u95TZqe+47/5vapl87Htyer/LXrCfjslytl7fzesudt1PbuvVrgttfG+dSWGmOZ9jlncu9eZaJBmDliiFqu8ikQ5JlCQA5YzLlIrLezOwRM5swsyPztq02s6fM7Fj7d7jKoRDipmEhH+P/GMCDb9r2MICn3X0XgKfbfwshbmKuGeztfutvrs7wEIBH248fBfCBJfZLCLHE3Oh39g3ufrVf7zhaHV2DmNleAHsBoKfIa6gLIZaXRa/Gu7sjsirg7vvcfY+77ynk+T3dQojl5UaD/ZyZjQBA+/fE0rkkhFgObvRj/DcBfBTAZ9q/v7GgUQ7USQZbPVIRsU4yqDaNbqJjYt8YDh4+Qm3nZ7iMNuth+adu/B
PLhg1cAhwY4FJTvhEpvmh8rmYrYWnryhyXvLyPS4Db7+LtjqqRIpBPnQxLb9Uql5OGmvx07BsNS2gAMLiBZ+298MLzwe1ZPx/TP9hPbU3nmXm3rOevdanMZbmB/nAGWz6S6WckjmIsRHr7CoAfALjNzM6Y2cfQCvL3mNkxAL/a/lsIcRNzzSu7u3+YmH5liX0RQiwjul1WiERQsAuRCAp2IRJBwS5EInS44KQDHpaNLJKtU2+EZaN169fSMa+/foLazlzkklGpySWSJpEAB/p4Rtb6DeuoDVWeQVWa5n3sega41JcnGXGnT5+iY+6K9HPLDYVlIQAYXMelzxkLn1ob1nEJbeYKl+XWjGyhtumICjU3G57jvkF+nRtexbMir5QvUlt+Bc8Hm7gwTm0VhF/PmkeKbLKebjyMdGUXIhUU7EIkgoJdiERQsAuRCAp2IRJBwS5EInRWejMgXwhLW339XE7aRHqRnXyN9906dZ5n3Zacy2sRtQP9+XDftls3ReS1Os9eK5V5McpiP3ekZ5jPVbUUlpomznIpcmyc92wrDPPssKrx1MJVW8IFLrfeuoOOObL/R9R2y+28P99ceZLaenvCGWxW5BmHg+tW8GNV+Gs2ZTxbbttdvKjnlRdfC26v57nsWSdyaUR505VdiFRQsAuRCAp2IRJBwS5EIijYhUiEjifCOMItlGKr8YVi+D1p7CRfRb4SXjgHANQjVW57wLMqNo2EEx36jCe0zJV5vbhmk9eFK/ZxHxsF/uS8HraVJ8/TMUePHKK2ddvfyv3I8dX4B977PjKGrxffepmrJKsi7auqY69Q28b1JFlndJSOGTsfrp8HALMNfn2sVXlC1MbRXdTWNx5+zZr5WC08fs4xdGUXIhEU7EIkgoJdiERQsAuRCAp2IRJBwS5EInRUems6MEPaPJ2Y4MkMZydmg9tnZ7hMVi3wp1ZwXutsx0ZeR2xVX1g2qld4vbhGk8sxuQJPxqhHMhqYvAYARhJv8pHkn8vHj1Jb7QpPKMp6uTRk9Y3hY5Vm6Jjb3vVr1Na/grdWGskibbQGwudbbXCQjrk4yc/FKedy6WSNz7Gd49Ln6I63BLevXjNEx5Smw3UUY7UcF9L+6REzmzCzI/O2fdrMzprZofbP+6+1HyFEd1nIx/g/BvBgYPvn3f3u9s+TS+uWEGKpuWawu/t+AJc64IsQYhlZzALdx83scPtjPv2ia2Z7zeyAmR2oR75rCiGWlxsN9i8C2AngbgBjAD7L/tHd97n7Hnffk893+FZ8IcRPuaFgd/dz7t5w9yaALwG4d2ndEkIsNTd0qTWzEXcfa//5QQBHYv9/FQdQQ1iKusTVMOQaYTmhWeeZYUZaRgHA5g091LZlFa/7NTMbbv1Ti8hkDdamB0AjUsOt4Py5eYVLjs1y2NZwfqzqJb4kM3v5ArWdnuZfy5gfF0u8Jt+a9RHZM1LfLctzyWuyJ+xjT5HPx9bb7qG2vlW85VUtC2d0AsDa9Vw6XJEPS4drhwfomPEGkXudnxvXDHYz+wqAXwaw1szOAPgUgF82s7vRit9TAH7zWvsRQnSXawa7u384sPnLy+CLEGIZ0e2yQiSCgl2IRFCwC5EICnYhEqGjd7kYgALCMklmXD4hah2cZNABQE+BSxBr1g5TW6nCpaGZmbCtb4hnJxXy3I96jR8LkWy56Wk+ricfzuaabfCXemWB+//qOM96+8HLvP3WYBaWN3uaXJ6qjW6ltuZaLkPViDQLACfOngpuX79pLR2zcWU4Yw8Abuvn505W4/pxgxRaBYCZalhGm5vlGYIekXQZurILkQgKdiESQcEuRCIo2IVIBAW7EImgYBciETqcYG7IkfcXMy4lUFvGpbdVK3jWWG8vf9pzJV48skIS6QqRtLd1Iyuo7fzFcNFAAKizgwHwKn+P7usJy2jViCxUneO2+twUP1aRy4qvvfhCcHvz/DgdszLP53Hl4D+jtv61XCpbQWTW46dfo2OsxmXP0Q
LPlps6znvm5fv4PmeGwtl+lTl+Luay8PltEQlbV3YhEkHBLkQiKNiFSAQFuxCJoGAXIhE6mwhjPOHFIrWz3MOr7pnxGmhrV/FV8NJcuJ0UAFQiK9Nm4RXQapWrAt7ktmppjh8rUjMui9SuuzwZTpKpXLlMx1Q8Ukvu3E+o7aGf/3lqq+0M12r7wV9+m475+28/xvdX5Sv1t7/z3dS2YddocHvfCn5+TI2do7YzL/+A2qovf4fahtetobb+u/55cPuqFSvpmHPgCUoMXdmFSAQFuxCJoGAXIhEU7EIkgoJdiERQsAuRCAvpCLMZwJ8A2IBWB5h97v4FM1sN4M8BbEOrK8yvuzvXd9Cqm9UgdeOMyGsAkDXDslwGLhk1ylzWulTl9cBAauQBwIqBcB00M+779BSX+QZ6eaup6SmeBGFF3r7q/KVwcs36PE/EwMXXqenk/ieprXHpFLUNNsPtmnYMcgntpPNWU/u/9XVqO3qM+/9LD/1ScPttb7uPjlk5wls1VaqruW2Sv2bTc1zCXLMyPFfbt2yhY469FK7/Z5HzdyFX9jqA33b3OwC8E8BvmdkdAB4G8LS77wLwdPtvIcRNyjWD3d3H3P1g+/E0gKMARgE8BODR9r89CuADy+WkEGLxXNd3djPbBuDtAJ4BsGFeJ9dxtD7mCyFuUhYc7GY2COBxAJ9w9zdUNPBWEevglzEz22tmB8zsQL3Ov2MLIZaXBQW7tW4KfxzAY+7+RHvzOTMbadtHgPDNuu6+z933uPuefL7DhXGEED/lmsFurTo3XwZw1N0/N8/0TQAfbT/+KIBvLL17QoilYiGX2vsBfATAC2Z2tcjWJwF8BsDXzOxjAF4F8OsLOWCDKAPe4B/xrRaWygoWyZSrRDK5+DBYgU9JA2GJrV7n7ZimJ/n76WAvr5PXV+S22TqXDiuN8JOrOpfeipEJ2bWGZ14Vpi5S2/kTLwW3b93M2y7dsytciw0Avvssr4V38uBBaht7/UfB7fe96wQd87bd76C2O0d4vbuJHD93htdxya5/KCzBRsRSFIn8GqtBd81gd/fvg4vPv3Kt8UKImwPdQSdEIijYhUgEBbsQiaBgFyIRFOxCJEJH73JpuqNcD7c1GiDFHAGgn8gJw729dEzO+ftYLDOo3uRZWbV6uBhlrcYz2/JZOKMJAPoiNxn19fHn9tpp3kIpPxiWtl4/x6WrTZHii/lmpG3RuUvUZjMkA2yaZ6htXBeRKSNFMWvGfbw0HpbYvvetP6Njzpw6w/f3lnAhTQDof4X7uH1HuA0VAJRJi62xs2fpmEY9LAM7P311ZRciFRTsQiSCgl2IRFCwC5EICnYhEkHBLkQidCHBPKwNZFwNwxoisa3v5YUXZ51nhjnJDGsZI6Z82Oh1PqjhYakRADDA/Z+bnKS2njqX+qYuhOXBUmWIjnl9hr/nz/yIZ4dlFd4TjRW4PDbD5alVCPdlA4DhAp/j1TmefddfDs/jhTI/9Z/92/3U9pN/4EVC7+cJcRif4r3Zfnz4D4PbTxznMmWdnXMR7U1XdiESQcEuRCIo2IVIBAW7EImgYBciETq8Gu/ISJunRoOvnvcjnCQzUOEtniaLkeV956vxhViSTCls6+nhddr6I3Xmxif4KvL6QZ7c8Z63bqW2C+fCLZT+7zHeoupMuUhtwxlffb5nC189f+v68Fy9/BOePHP8DD/WraP8VN09wq9ZL54Jt+z69ngkCcl4TcFypBz6kYt8JbxwmZ/f52ZPB7c7uFpjGfHftBovRPIo2IVIBAW7EImgYBciERTsQiSCgl2IRLim9GZmmwH8CVotmR3APnf/gpl9GsC/AXC+/a+fdPcno/uCIfOwJOOR951ZUm9rMlKXbLbJJbSsyKWmaGsoliXT5PLa+WleS65R4uPWFPlz2zHMJartA+HO2c+dvUzHTJe4LDcwxNsW5QvhpBsAWDk8GNy+bec6Oubky+epbVUPfz2bEXlwdPPO4PbBSzzRaChSd6
8eOU+9TuruAWhmXBIb6g/XAJyZ4RKgsYSXSCLXQnT2OoDfdveDZjYE4Dkze6pt+7y7//cF7EMI0WUW0uttDMBY+/G0mR0FIrmIQoibkuv6zm5m2wC8HcAz7U0fN7PDZvaImfEWnEKIrrPgYDezQQCPA/iEu08B+CKAnQDuRuvK/1kybq+ZHTCzA/VIW2YhxPKyoGA3swJagf6Yuz8BAO5+zt0b7t4E8CUA94bGuvs+d9/j7nvykf7VQojl5ZrBbq3u7l8GcNTdPzdv+8i8f/sggCNL754QYqlYyKX2fgAfAfCCmR1qb/skgA+b2d1oLfafAvCbCzlgRiQxB5c7poj01oxIbyXjT61A5D8AKOS5HFYohCW7Up3v73yJ+zFovC7c5cu89c/UeS6j9eXD2mGRPy34HJfeylVuOzvOsw4vrA63O5pr8vkoNSIS5hTPGjt9kUtUDVK/MFY30MsRySt2fbRIiyrEZOLw9not4iM18DELWY3/PhDM+4xq6kKImwvdQSdEIijYhUgEBbsQiaBgFyIRFOxCJEJH73IxABl7fzEuu8yRVk6NSIZPLXKzXlbmxj5S3BIApsthqalc4alymfP3U4u81U6VufGHP+bZYetIRlyW55lhWZFP5GxEohqLtHI6+HK4NVRuIFwAEgBmnWcI1ievUFveefZd76rw69mIZEV6g8uNsTZl+R5eILLhPNSaxJd8jp8Dler1342qK7sQiaBgFyIRFOxCJIKCXYhEULALkQgKdiESoaPSmzvQoAUnedZbmfSHq0ZkLTQjtiqXkyoNnvHkJKMob1yPGaxzyShW3LKc4xlxL01zH4eb4cy8mSaXk6qR9/x6RN9s1HnhzjNT4XFZ5ElfqvHTsbfJZS2f5fNRr8wGt1+pRqSwLJIiGNFLq5FzuBKZR8+Fj2cRna9Ier1ZxD9d2YVIBAW7EImgYBciERTsQiSCgl2IRFCwC5EInZXeALAaek3jsoXlwrILk/EAIIs8tWa0oRv3I1iJD0CrmnaYYlaltplIocpqPtwrDQDKEVuJFfRszNAxlRx/zs3I9SDr6ae2SdLb7Mokl8mmCiupbSDHs7zKXFXEXO36pV5k3BZ7reuRHnG1iIxWI+dxLpLVWamH56MZafamK7sQiaBgFyIRFOxCJIKCXYhEULALkQjXXI03s14A+wH0tP//L9z9U2a2HcBXAawB8ByAj7g7X3oGYFmGXA9Jnoisgmfkpv9mrAZdJPGgUOC1znJ5PiU1sgJajdQsm4k8r2oPP5bnuY9ufEV4thZe7TYmJQAoR1aKLdI3qkoSOADgcjVcF24SPHlmxvlc1SPJKb0Zv2bN5cLHa0ZaKxUideaMnIsAUIrMcUxtYm3MYglWNfKUIyGxoCt7BcC73f0utNozP2hm7wTwBwA+7+63ArgM4GML2JcQoktcM9i9xVWRttD+cQDvBvAX7e2PAvjAsngohFgSFtqfPdfu4DoB4CkArwC44v7Tzx9nAIwuj4tCiKVgQcHu7g13vxvAJgD3Arh9oQcws71mdsDMDtTJd14hxPJzXavx7n4FwN8AuA/AsNlPm6BvAhBsKO7u+9x9j7vvyUcWv4QQy8s1g93M1pnZcPtxH4D3ADiKVtD/i/a/fRTAN5bLSSHE4lnIpXYEwKNmlkPrzeFr7v6XZvYSgK+a2X8G8CMAX77WjhyAkxpZ5Wq4xRMANHJhCaIJLnlF1DDUIvXYchHJrkZ6SnkksWYu0nYp1r7KGlzF7HGeTJJrhm2NiCgTrtLWJuPHKkfkpFw17H+VSGEAMNPk50C5yecjn/H5r7KklkasZRf/uplFZL5mJDELERmtXi8Ft+cj2jJ7Wk1SJxFYQLC7+2EAbw9sP4HW93chxM8AuoNOiERQsAuRCAp2IRJBwS5EIijYhUgEYy2NluVgZucBvNr+cy2ACx07OEd+vBH58UZ+1vzY6u7rQoaOBvsbDmx2wN33dOXg8kN+JOiHPsYLkQgKdiESoZvBvq
+Lx56P/Hgj8uON/KPxo2vf2YUQnUUf44VIBAW7EInQlWA3swfN7MdmdtzMHu6GD20/TpnZC2Z2yMwOdPC4j5jZhJkdmbdttZk9ZWbH2r9XdcmPT5vZ2facHDKz93fAj81m9jdm9pKZvWhm/669vaNzEvGjo3NiZr1m9kMze77tx++3t283s2facfPnZsbzhUO4e0d/AOTQqmG3A0ARwPMA7ui0H21fTgFY24Xj/iKA3QCOzNv23wA83H78MIA/6JIfnwbwOx2ejxEAu9uPhwD8BMAdnZ6TiB8dnRO0WogOth8XADwD4J0AvgbgQ+3tfwjg317PfrtxZb8XwHF3P+GtOvNfBfBQF/zoGu6+H8ClN21+CK0qvUCHqvUSPzqOu4+5+8H242m0KiGNosNzEvGjo3iLJa/o3I1gHwVwet7f3axM6wC+a2bPmdneLvlwlQ3uPtZ+PA5gQxd9+biZHW5/zF/2rxPzMbNtaBVLeQZdnJM3+QF0eE6Wo6Jz6gt0D7j7bgDvA/BbZvaL3XYIaL2zI97cYzn5IoCdaDUEGQPw2U4d2MwGATwO4BPuPjXf1sk5CfjR8TnxRVR0ZnQj2M8C2Dzvb1qZdrlx97Pt3xMAvo7ultk6Z2YjAND+PdENJ9z9XPtEawL4Ejo0J2ZWQCvAHnP3J9qbOz4nIT+6NSftY193RWdGN4L9WQC72iuLRQAfAvDNTjthZgNmNnT1MYD3AjgSH7WsfBOtKr1AF6v1Xg2uNh9EB+bEzAytgqVH3f1z80wdnRPmR6fnZNkqOndqhfFNq43vR2ul8xUA/6FLPuxASwl4HsCLnfQDwFfQ+jhYQ+u718fQapD5NIBjAP4KwOou+fGnAF4AcBitYBvpgB8PoPUR/TCAQ+2f93d6TiJ+dHROALwNrYrNh9F6Y/mP887ZHwI4DuB/Aei5nv3qdlkhEiH1BTohkkHBLkQiKNiFSAQFuxCJoGAXIhEU7EIkgoJdiET4f9EFpb7UpvLoAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gDYzJEOKhnYQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def del_weights():\n",
        "    if os.path.exists(\"model-best.h5\"):\n",
        "        !rm -rf model-best.h5\n",
        "        print(\"Previous weights deleted!\")\n",
        "\n",
        "del_weights()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Qhxqh4gU3swU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load the weights\n",
        "# Build a fresh network, fetch \"model-best.h5\" from the referenced W&B run, and load the\n",
        "# weights from the path exposed by the returned object's `.name`.\n",
        "resnet_adam = get_training_model(\"adam\")\n",
        "resnet_adam_weights = wandb.restore(\"model-best.h5\", run_path=\"sayakpaul/training-bn-only/resnet-ramups-adam\")\n",
        "resnet_adam.load_weights(resnet_adam_weights.name)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WO-YsgY5GRZp",
        "colab_type": "code",
        "outputId": "8a0039b4-4037-4038-c566-b093a429be76",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 391
        }
      },
      "source": [
        "# Inspect the output shapes of the conv layers\n",
        "conv_ids = []\n",
        "for i, layer in enumerate(resnet_adam.layers):\n",
        "    # Only layers whose name contains \"conv\" are of interest\n",
        "    if \"conv\" in layer.name:\n",
        "        # Report index, name and output shape, and remember the index\n",
        "        print(i, layer.name, layer.output.shape)\n",
        "        conv_ids.append(i)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "1 conv2d_166 (None, 32, 32, 16)\n",
            "4 conv2d_168 (None, 32, 32, 16)\n",
            "7 conv2d_169 (None, 32, 32, 16)\n",
            "10 conv2d_170 (None, 32, 32, 64)\n",
            "11 conv2d_167 (None, 32, 32, 64)\n",
            "15 conv2d_171 (None, 32, 32, 16)\n",
            "18 conv2d_172 (None, 32, 32, 16)\n",
            "21 conv2d_173 (None, 32, 32, 64)\n",
            "25 conv2d_175 (None, 32, 32, 64)\n",
            "28 conv2d_176 (None, 16, 16, 64)\n",
            "31 conv2d_177 (None, 16, 16, 128)\n",
            "32 conv2d_174 (None, 16, 16, 128)\n",
            "36 conv2d_178 (None, 16, 16, 64)\n",
            "39 conv2d_179 (None, 16, 16, 64)\n",
            "42 conv2d_180 (None, 16, 16, 128)\n",
            "46 conv2d_182 (None, 16, 16, 128)\n",
            "49 conv2d_183 (None, 8, 8, 128)\n",
            "52 conv2d_184 (None, 8, 8, 256)\n",
            "53 conv2d_181 (None, 8, 8, 256)\n",
            "57 conv2d_185 (None, 8, 8, 128)\n",
            "60 conv2d_186 (None, 8, 8, 128)\n",
            "63 conv2d_187 (None, 8, 8, 256)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xS4ozcgpaAVY",
        "colab_type": "code",
        "outputId": "9d17e58d-4bfd-465f-cbd1-ee5ed5ac5ab5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 527
        }
      },
      "source": [
        "# conv layer at 10th index\n",
        "# Truncated model: shares the inputs of `resnet_adam` and stops at the output of layer index 10,\n",
        "# so predicting with it returns that layer's feature maps.\n",
        "model = Model(inputs=resnet_adam.inputs, outputs=resnet_adam.layers[10].output)\n",
        "model.summary()"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"model_2\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "input_2 (InputLayer)         [(None, 32, 32, 3)]       0         \n",
            "_________________________________________________________________\n",
            "conv2d_166 (Conv2D)          (None, 32, 32, 16)        448       \n",
            "_________________________________________________________________\n",
            "batch_normalization_164 (Bat (None, 32, 32, 16)        64        \n",
            "_________________________________________________________________\n",
            "re_lu_164 (ReLU)             (None, 32, 32, 16)        0         \n",
            "_________________________________________________________________\n",
            "conv2d_168 (Conv2D)          (None, 32, 32, 16)        272       \n",
            "_________________________________________________________________\n",
            "batch_normalization_165 (Bat (None, 32, 32, 16)        64        \n",
            "_________________________________________________________________\n",
            "re_lu_165 (ReLU)             (None, 32, 32, 16)        0         \n",
            "_________________________________________________________________\n",
            "conv2d_169 (Conv2D)          (None, 32, 32, 16)        2320      \n",
            "_________________________________________________________________\n",
            "batch_normalization_166 (Bat (None, 32, 32, 16)        64        \n",
            "_________________________________________________________________\n",
            "re_lu_166 (ReLU)             (None, 32, 32, 16)        0         \n",
            "_________________________________________________________________\n",
            "conv2d_170 (Conv2D)          (None, 32, 32, 64)        1088      \n",
            "=================================================================\n",
            "Total params: 4,320\n",
            "Trainable params: 4,224\n",
            "Non-trainable params: 96\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9DRiFWUEW5bB",
        "colab_type": "code",
        "outputId": "51e643c8-ec92-462d-ddbf-c7e36941aca6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Extending to multiple images\n",
        "# Draw 8 random positions in the test set\n",
        "idx = np.random.choice(x_test.shape[0], 8)\n",
        "# Index with the whole array at once to gather the selected images into one batch array\n",
        "few_images = x_test[idx]\n",
        "\n",
        "few_images.shape"
      ],
      "execution_count": 15,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(8, 32, 32, 3)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2C8J4wZScYLv",
        "colab_type": "code",
        "outputId": "857257d8-ea67-4b09-d326-1bc7b91ce38f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Extract the feature maps\n",
        "# Run the truncated model on the sampled batch; the result holds one stack of feature maps per image.\n",
        "feature_maps_multi = model.predict(few_images)\n",
        "feature_maps_multi.shape"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(8, 32, 32, 64)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Nysg4xmvemTE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        },
        "outputId": "d2cd374f-250d-4ecf-f85d-7485b56bfa6b"
      },
      "source": [
        "# Initialise a W&B run with id \"visualization-resnet\" in the \"training-bn-only\" project before logging the figure.\n",
        "wandb.init(project=\"training-bn-only\", id=\"visualization-resnet\")"
      ],
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
              "                Project page: <a href=\"https://app.wandb.ai/sayakpaul/training-bn-only\" target=\"_blank\">https://app.wandb.ai/sayakpaul/training-bn-only</a><br/>\n",
              "                Run page: <a href=\"https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-resnet\" target=\"_blank\">https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-resnet</a><br/>\n",
              "            "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "W&B Run: https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-resnet"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I39E62MRceoP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Visualize\n",
        "# Grid layout: one row per image; column 0 shows the input image, the remaining\n",
        "# columns show feature-map channels.\n",
        "# NOTE: the channel index is the running subplot counter (ix - 1), so every row shows a\n",
        "# different set of channels and the layer must provide at least square * square channels.\n",
        "square = 8\n",
        "ix = 1\n",
        "fig = plt.figure(figsize=(16, 16))\n",
        "for i, feature_map in enumerate(feature_maps_multi[:square]):\n",
        "    # Add a leading batch axis once per image so the slicing below can use index 0\n",
        "    if np.ndim(feature_map) != 4:\n",
        "        feature_map = np.expand_dims(feature_map, 0)\n",
        "    for j in range(square):\n",
        "        ax = plt.subplot(square, square, ix)\n",
        "        ax.set(xticks=[], yticks=[])\n",
        "        if j == 0:\n",
        "            ax.imshow(few_images[i].squeeze())\n",
        "        else:\n",
        "            ax.imshow(feature_map[0, :, :, ix - 1], cmap=\"binary\")\n",
        "        ix += 1\n",
        "\n",
        "# Send the finished figure to the active W&B run, then release it\n",
        "wandb.log({\"10th_conv_filter\": fig})\n",
        "plt.close(fig)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0P6vAsqOf-2R",
        "colab_type": "text"
      },
      "source": [
        "**Note**: You may need to delete the previously downloaded model weights first. The preferred approach is to manually download the model weights from [this run page](https://app.wandb.ai/sayakpaul/training-bn-only/runs/bn-only-adam/files?workspace=user-sayakpaul) and upload the file to Colab."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9AFW-9-Ij4sF",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load the weights of BN only ResNet\n",
        "# The automatic clean-up and the W&B download are left disabled below; the weights are read\n",
        "# from a local \"model-best.h5\" instead. The same file name was used for the first model's\n",
        "# weights earlier, so the file on disk must already be the BN-only checkpoint when this runs.\n",
        "# del_weights()\n",
        "resnet_bn_adam = get_training_model(\"adam\")\n",
        "# resnet_adam_bn_weights = wandb.restore(\"model-best.h5\", run_path=\"sayakpaul/training-bn-only/bn-only-adam\")\n",
        "resnet_bn_adam.load_weights(\"model-best.h5\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xPga7V_Qg9vr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        },
        "outputId": "3d9341be-e873-47ff-c543-62ee3b79d200"
      },
      "source": [
        "# Initialise a separate W&B run with id \"visualization-bn-only\" for the BN-only model's figure.\n",
        "wandb.init(project=\"training-bn-only\", id=\"visualization-bn-only\")"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "                Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
              "                Project page: <a href=\"https://app.wandb.ai/sayakpaul/training-bn-only\" target=\"_blank\">https://app.wandb.ai/sayakpaul/training-bn-only</a><br/>\n",
              "                Run page: <a href=\"https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-bn-only\" target=\"_blank\">https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-bn-only</a><br/>\n",
              "            "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "W&B Run: https://app.wandb.ai/sayakpaul/training-bn-only/runs/visualization-bn-only"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CBXg7QSGkW3R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# conv layer at 10th index\n",
        "model = Model(inputs=resnet_bn_adam.inputs, outputs=resnet_bn_adam.layers[10].output)\n",
        "\n",
        "# Extract the feature maps\n",
        "feature_maps_multi = model.predict(few_images)\n",
        "\n",
        "# Visualize\n",
        "# Grid layout: one row per image; column 0 shows the input image, the remaining\n",
        "# columns show feature-map channels.\n",
        "# NOTE: the channel index is the running subplot counter (ix - 1), so every row shows a\n",
        "# different set of channels and the layer must provide at least square * square channels.\n",
        "square = 8\n",
        "ix = 1\n",
        "fig = plt.figure(figsize=(16, 16))\n",
        "for i, feature_map in enumerate(feature_maps_multi[:square]):\n",
        "    # Add a leading batch axis once per image so the slicing below can use index 0\n",
        "    if np.ndim(feature_map) != 4:\n",
        "        feature_map = np.expand_dims(feature_map, 0)\n",
        "    for j in range(square):\n",
        "        ax = plt.subplot(square, square, ix)\n",
        "        ax.set(xticks=[], yticks=[])\n",
        "        if j == 0:\n",
        "            ax.imshow(few_images[i].squeeze())\n",
        "        else:\n",
        "            ax.imshow(feature_map[0, :, :, ix - 1], cmap=\"binary\")\n",
        "        ix += 1\n",
        "\n",
        "# Send the finished figure to the active W&B run, then release it\n",
        "wandb.log({\"10th_conv_filter\": fig})\n",
        "plt.close(fig)"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}